%This script runs the simulations used for Figures B.1 and B.2. 
%Corresponding author: Rodrigo Adao
%Date: 09/11/2024
%Input: model_data_fgkk.mat
%Output: Figures B.1 and B.2

%% Preliminaries

clear all;
close all;

%Trade-share target (baseline 0.9). Per the original note it defines the
%sample of large varieties (presumably inside simulation_preliminary_steps
%- confirm); it also selects the simulation_shares_t<target>.mat file that
%is loaded when the IV matrices are built.
trade_target = 0.9; %baseline has 0.9

%Simulation parameters
Nsimul = 1500; %Number of simulations (baseline 1500)

%% Import Data
%set local paths
%Build the helper-function folder with the platform's file separator so the
%script also runs outside Windows. The trailing separator keeps the value
%identical to the previous "..\Functions\" on Windows, in case the path
%scripts below concatenate onto function_path.
function_path = fullfile("..", "Functions") + filesep;
addpath(function_path,'-frozen');
set_local_paths;

%Import data and compute researcher predictions
simulation_preliminary_steps;

%% Create IV matrices 

%Load the per-observation general-equilibrium adjustment vectors for the
%four IV variants (wNC, wMCa, wMC, sMC).
load(save_share_path + "simulation_shares_t" + trade_target + ".mat", 'adjGE_wNC', 'adjGE_wMCa', 'adjGE_wMC', 'adjGE_sMC')

%Temporary helpers:
%  rowScale(v)          -> diag(v)*shareMOD_dW, with diag(v) built as a sparse Nobs x Nobs matrix
%  residualize(M, T, X) -> X - M*(T*X), i.e. X net of the control adjustment
rowScale    = @(v) spdiags(v, 0, Nobs, Nobs)*shareMOD_dW;
residualize = @(M, T, X) X - M*(T*X);

%No controls: only the GE row rescaling
share_wNC_dW  = rowScale(adjGE_wNC);

%Restricted controls use (Ca, MCa_term1); baseline and sMC use (C, MC_term1)
share_wMCa_dW = residualize(Ca, MCa_term1, rowScale(adjGE_wMCa));
share_wMC_dW  = residualize(C,  MC_term1,  rowScale(adjGE_wMC));
share_sMC_dW  = residualize(C,  MC_term1,  rowScale(adjGE_sMC));

%Drop the helpers so only the share matrices and adjGE_* vectors remain
clear rowScale residualize

%% Compute equilibrium given different draws of parameters

%Shock parameters
avg_sh = .02; %average tariff shock
std_sh = .06; %st dev of tariff shock
std_ep = .06; %st dev of other shocks

%Misspecification scenarios: each entry is the mean of gamma_n used for one
%run of the simulation loop
gamma_grid_plot = [-1.5, -1, -.5, -.25, 0, .25, .5,  1, 1.5]';
Ngamma = numel(gamma_grid_plot);

%All other parameters are held fixed across scenarios, so every grid is a
%constant column with one entry per scenario.
const_col = ones(Ngamma, 1);

%Shifter (tariff shock) distribution
sd_shifters_grid   = std_sh*const_col;
mean_shifters_grid = avg_sh*const_col;

%Normal draws of the remaining shocks
mean_deta_grid = 0*const_col;
sd_da_grid     = std_ep*const_col;
sd_dzstar_grid = std_ep*const_col;
sd_dastar_grid = std_ep*const_col;

%Parameters of the true DGP (baseline values, unit rho/gamma shifters)
sigma_grid      = sigma*const_col;
omega_star_grid = omega_star*const_col;
eta_grid        = eta*const_col;
kappa_grid      = kappa*const_col;
sigma_star_grid = sigma_star*const_col;
rho_exp_grid    = 1*const_col;
rho_imp_grid    = 1*const_col;
gammaM_grid     = 1*const_col;
gammaX_grid     = 1*const_col;
gammaQ_grid     = 1*const_col;

%Misspecification means follow the scenario grid; no dispersion around them
gamma_mean_px_grid = gamma_grid_plot;
gamma_mean_pm_grid = gamma_grid_plot;
gamma_mean_rm_grid = gamma_grid_plot;

gamma_sd_px_grid = zeros(size(gamma_mean_px_grid));
gamma_sd_pm_grid = zeros(size(gamma_mean_pm_grid));
gamma_sd_rm_grid = zeros(size(gamma_mean_rm_grid));

clear const_col

%One parameter configuration per scenario
Npar = Ngamma;

Npar
%Storage: 22 statistics x Nsimul draws x Npar configurations
ests = zeros(22, Nsimul, Npar);

%For each parameter configuration j: build the true DGP, compute its
%first-order equilibrium shares, run Nsimul simulated draws of shocks in
%parallel, store one 22-row statistics vector per draw in ests(:,n,j), and
%save the results file after every configuration.
for j=1:Npar
    tic
    display('------------- running new j -----------------')
    j
    
    %shock process
    sd_shifters  = sd_shifters_grid(j);
    mean_shifters = mean_shifters_grid(j);
    %Scalar scale factor applied to the IVs and IV share matrices below
    adj = Nobs*(mean_shifters/sd_shifters);

    mean_deta = mean_deta_grid(j);
    sd_da     = sd_da_grid(j);
    sd_dzstar = sd_dzstar_grid(j);
    sd_dastar = sd_dastar_grid(j);
    
    %Generate true DGP
    omega_starDGP = omega_star_grid(j);    
    sigmaDGP = sigma_grid(j);           
    etaDGP = eta_grid(j);               
    kappaDGP = kappa_grid(j);           
    sigma_starDGP = sigma_star_grid(j); 
    rhoMDGP = rho_imp_grid(j) ;
    rhoXDGP = rho_exp_grid(j) ;
    gammaMDGP = gammaM_grid(j) ;
    gammaXDGP = gammaX_grid(j) ;
    gammaQDGP = gammaQ_grid(j) ;    
    
    %Observation-level misspecification: gamma_n is drawn with the pm, rm
    %and px grid moments for observations flagged by indM, indT and indX
    %respectively; gammaDGP = diag(1 + gamma_n) rescales the rows of the
    %true-DGP shares.
    gamma_n = (indM==1).*( normrnd(gamma_mean_pm_grid(j), gamma_sd_pm_grid(j), Nobs, 1) )... 
            + (indT==1).*( normrnd(gamma_mean_rm_grid(j), gamma_sd_rm_grid(j), Nobs, 1) )...
            + (indX==1).*( normrnd(gamma_mean_px_grid(j), gamma_sd_px_grid(j), Nobs, 1) );
    gammaDGP = spdiags(1+ gamma_n, 0, Nobs,Nobs);

    [delta_ig_DGP0, a_star_ig_DGP0, z_star_ig_DGP0, a_ig_DGP0, aM_g_DGP0, AM_s_DGP0, AD_s_DGP0, betaT_DGP0, beta_s_DGP0, alphaL_s_DGP0, alphaI_s_DGP0, alpha_ks_DGP0, barL_rs_DGP0, Z_rs_DGP0, D_DGP0, E_s_DGP0, F_DGP0] = ... 
        invert_param_full(x_ig_0, m_ig_0, tau_star_ig_0, tau_ig_0, trade_con_share, tot_labor_comp, tot_intermediate, sales_s, IO_sales, Lsr, ...
        Dsg, DX, DM, omega_starDGP, sigmaDGP, etaDGP, kappaDGP, sigma_starDGP, gammaMDGP, gammaXDGP,  tol_parm);
 
    [p_s_IV, PM_s_IV, ...
         rhoDGP_qm_ig_tau_ig, rhoDGP_qm_ig_taustar_ig, rhoDGP_qm_ig_a_ig, rhoDGP_qm_ig_zstar_ig, rhoDGP_qm_ig_astar_ig, rhoDGP_qm_ig_delta_ig, ...
         rhoDGP_pm_ig_tau_ig, rhoDGP_pm_ig_taustar_ig, rhoDGP_pm_ig_a_ig, rhoDGP_pm_ig_zstar_ig, rhoDGP_pm_ig_astar_ig, rhoDGP_pm_ig_delta_ig, ...
         rhoDGP_px_ig_tau_ig, rhoDGP_px_ig_taustar_ig, rhoDGP_px_ig_a_ig, rhoDGP_px_ig_zstar_ig, rhoDGP_px_ig_astar_ig, rhoDGP_px_ig_delta_ig, ...
         rhoDGP_pstar_ig_tau_ig, rhoDGP_pstar_ig_taustar_ig, rhoDGP_pstar_ig_a_ig, rhoDGP_pstar_ig_zstar_ig, rhoDGP_pstar_ig_astar_ig, rhoDGP_pstar_ig_delta_ig, ...
         rhoDGP_rm_ig_tau_ig, rhoDGP_rm_ig_taustar_ig, rhoDGP_rm_ig_a_ig, rhoDGP_rm_ig_zstar_ig, rhoDGP_rm_ig_astar_ig, rhoDGP_rm_ig_delta_ig] ...
         = compute_equilibrium_foa_full( delta_ig_DGP0, a_star_ig_DGP0, z_star_ig_DGP0, a_ig_DGP0, aM_g_DGP0, AM_s_DGP0, AD_s_DGP0, betaT_DGP0, beta_s_DGP0, alphaL_s_DGP0, alphaI_s_DGP0, alpha_ks_DGP0, barL_rs_DGP0, Z_rs_DGP0, D_DGP0, ... 
             tau_ig_0, tau_star_ig_0, m_ig_0, E_s_DGP0, p_s_0, Dsg, DX, DM, omega_starDGP, sigmaDGP, etaDGP, kappaDGP, sigma_starDGP, nu, epsilon, gammaQDGP, gammaXDGP, gammaMDGP, rhoXDGP, rhoMDGP, tol_conv, adj_inner_g0, adj_outter_g0, ... 
             ind_Mshifts_vec, ind_Xshifts_vec, ind_Msample_ig, ind_Xsample_ig ); 

    %True-DGP response of the stacked [pm; rm; px] outcomes to the shifters
    %(tau columns, then taustar columns), rescaled row-wise by gammaDGP
    shareDGP_pm      = [rhoDGP_pm_ig_tau_ig    , rhoDGP_pm_ig_taustar_ig   ];
    shareDGP_rm      = [rhoDGP_rm_ig_tau_ig    , rhoDGP_rm_ig_taustar_ig   ];
    shareDGP_px      = [rhoDGP_px_ig_tau_ig    , rhoDGP_px_ig_taustar_ig   ];
    shareDGP_dW = gammaDGP*[shareDGP_pm; shareDGP_rm; shareDGP_px];
  
    %Responses of the same stacked outcomes to the other shocks
    shareDGP_dW_da     = [rhoDGP_pm_ig_a_ig    ; rhoDGP_rm_ig_a_ig    ; rhoDGP_px_ig_a_ig    ];
    shareDGP_dW_dzstar = [rhoDGP_pm_ig_zstar_ig; rhoDGP_rm_ig_zstar_ig; rhoDGP_px_ig_zstar_ig];
    shareDGP_dW_dastar = [rhoDGP_pm_ig_astar_ig; rhoDGP_rm_ig_astar_ig; rhoDGP_px_ig_astar_ig];
    
    clear    rhoDGP* shareDGP_p* shareDGP_r* shareDGP_q* delta_ig_DGP0 a_star_ig_DGP0 z_star_ig_DGP0 a_ig_DGP0 aM_g_DGP0 barL_rs_DGP0 Z_rs_DGP0

    %These products do not depend on the simulation draw: compute them once
    %per configuration instead of once per parfor iteration.
    adj_share_wNC_dW  = adj*share_wNC_dW;
    adj_share_wMCa_dW = adj*share_wMCa_dW;
    adj_share_wMC_dW  = adj*share_wMC_dW;
    adj_share_sMC_dW  = adj*share_sMC_dW;
             
    toc
    tic
    %Simulation across draws of shocks 
    parfor n=1:Nsimul

    % simulated data
        %simulated shifts
        dtau_sim = normrnd(mean_shifters, sd_shifters, NMs, 1);
        dtaustar_sim = normrnd(mean_shifters, sd_shifters, NXs, 1);  
        shifters = [dtau_sim; dtaustar_sim];
        %Standardized shifters used to build the IVs
        shiftersIV = (shifters - mean(shifters))/std(shifters);
        
        %simulated shocks
        da_sim     = normrnd(mean_deta, sd_da    , NMs, 1);
        dzstar_sim = normrnd(mean_deta, sd_dzstar, NMs, 1);
        dastar_sim = normrnd(mean_deta, sd_dastar, NXs, 1);
        detaDGP = shareDGP_dW_da*da_sim + shareDGP_dW_dzstar*dzstar_sim + shareDGP_dW_dastar*dastar_sim ;

        %Simulated predictions (model shares) and outcomes (true DGP)
        dx     = shareMOD_dW*shifters;
        dxstar = shareDGP_dW*shifters ;
        dy = dxstar + detaDGP;
        dyt_NC = dy - dx;

        % Welfare statistics
        dW = ww_dW'*dx;    
        dWstar = ww_dW'*dxstar;
        Delta_W = dWstar - dW;
       
    % Tests with the true parameters
        
        %correlation-based tests
        correl_pred = corr([dy, dx ]);
        correl_pred = correl_pred(2,1);
        MSE_pred = mean( dyt_NC.^2 );
        [~,~,~,~,stats_slope] = regress(dy, [ones(Nobs,1), dx]);
        R2_pred = stats_slope(1);
        
        %naive -- no GE
        z = shareMOD_dW*shiftersIV;

        zw_wNC = adj*adjGE_wNC.*z; 
                
        zw_wMCa = adj*adjGE_wMCa.*z;
        zw_wMCa = zw_wMCa - Ca*(MCa_term1*zw_wMCa) ;
        
        zw_wMC = adj*adjGE_wMC.*z;
        zw_wMC = zw_wMC - C*(MC_term1*zw_wMC) ;

        zw_sMC = adj*adjGE_sMC.*z;
        zw_sMC = zw_sMC - C*(MC_term1*zw_sMC) ;

        [bt_zw_wNC , se_zw_wNC , rj_zw_wNC , r0_zw_wNC ] = implement_test(dyt_NC, zw_wNC , adj_share_wNC_dW , shiftersIV, critical, 0);       
        [bt_zw_wMCa, se_zw_wMCa, rj_zw_wMCa, r0_zw_wMCa] = implement_test(dyt_NC, zw_wMCa, adj_share_wMCa_dW, shiftersIV, critical, 0);   
        [bt_zw_wMC , se_zw_wMC , rj_zw_wMC , r0_zw_wMC ] = implement_test(dyt_NC, zw_wMC , adj_share_wMC_dW , shiftersIV, critical, 0);   
        [bt_zw_sMC , se_zw_sMC , rj_zw_sMC , r0_zw_sMC ] = implement_test(dyt_NC, zw_sMC , adj_share_sMC_dW , shiftersIV, critical, 0);       

        %Output (22 rows):
        %  1-6   dW, dWstar, Delta_W, correl_pred, R2_pred, MSE_pred
        %  7-10  bt_*  (wNC, wMCa, wMC, sMC)
        %  11-14 se_*  (same order)
        %  15-18 rj_*  (same order)
        %  19-22 r0_*  (same order)
       est_n = [dW, dWstar, Delta_W, correl_pred, R2_pred, MSE_pred, ...
               bt_zw_wNC, bt_zw_wMCa, bt_zw_wMC, bt_zw_sMC, ...
               se_zw_wNC, se_zw_wMCa, se_zw_wMC, se_zw_sMC, ...
               rj_zw_wNC, rj_zw_wMCa, rj_zw_wMC, rj_zw_sMC, ...
               r0_zw_wNC, r0_zw_wMCa, r0_zw_wMC, r0_zw_sMC];
    
    % save outcomes for step n
        ests(:,n,j) = est_n';
        
    end
    toc
    clear adj_share_wNC_dW adj_share_wMCa_dW adj_share_wMC_dW adj_share_sMC_dW
    %Written after every configuration so completed configurations are kept
    %on disk even if a later one fails
    save(save_simulation_path + "output_Figs_B1_B2.mat", 'ests', 'tradestat_sample', 'gamma_mean_px_grid','gamma_mean_pm_grid', 'gamma_mean_rm_grid', 'Npar', 'Ngamma', 'trade_target');
end

%% Report results

%Reload the saved simulation output (restores ests, Npar, Ngamma, ...)
load(save_simulation_path + "output_Figs_B1_B2.mat")

%Summarize the Nsimul draws of every configuration. Rows of ests:
%  1-6   dW, dWstar, Delta_W, correl_pred, R2_pred, MSE_pred
%  7-10  IV coefficients (wNC, wMCa, wMC, sMC)
%  11-14 standard errors, 15-18 rj outputs, 19-22 r0 outputs of implement_test
%Columns of ests_stat(:,:,j):
%  [mean, std (normalized by N), mean SE, mean r0, mean rj]
%Rows 1-6 have no SE/test counterpart, so columns 3-5 are zero there.
Nstat = 10; %statistics reported per configuration (rows 1-10 of ests)
ests_stat = zeros(Nstat, 5, Npar); %preallocate instead of growing in the loop
for j = 1:Npar
    ests_j = ests(:,:,j);

    estimates_j = ests_j(1:Nstat,:);
    SE   = [zeros(6,1); mean(ests_j(11:14,:),2)];
    rej  = [zeros(6,1); mean(ests_j(15:18,:),2)];
    rej0 = [zeros(6,1); mean(ests_j(19:22,:),2)];

    ests_stat(:,:,j) = [mean(estimates_j,2), std(estimates_j,1,2), SE, rej0, rej];
end

%% Figure B1

%Formatting constants (also reused by the Figure B2 section)
title_size  = 14;
label_size  = 20;
labely_size = 16;
legend_size = 18;
marker_size = 10;
axes_size   = 18;

%Horizontal range for the welfare-gap axis
w_min = -.06;
w_max = .06;

%Rows of ests_stat holding the IV coefficient of each instrument variant
ACDIV_wNC  = 7;
ACDIV_wMCa = 8;
ACDIV_wMC  = 9;
ACDIV_sMC  = 10;

%Configurations to plot; x-axis is the mean of row 3 (Delta_W) of each
j0 = 1;
jf = Ngamma;
DeltaW = squeeze(ests_stat(3,1,j0:jf));
x_axis = DeltaW;
xmin = w_min;
xmax = w_max;

%Column 1 of ests_stat: mean estimated IV coefficient across draws
beta_IV_wNC  = squeeze(ests_stat(ACDIV_wNC ,1,j0:jf));
beta_IV_wMCa = squeeze(ests_stat(ACDIV_wMCa,1,j0:jf));
beta_IV_wMC  = squeeze(ests_stat(ACDIV_wMC ,1,j0:jf));
beta_IV_sMC  = squeeze(ests_stat(ACDIV_sMC ,1,j0:jf));

%Column 4 of ests_stat: mean of the r0 output of implement_test
rej0_IV_wNC  = squeeze(ests_stat(ACDIV_wNC ,4,j0:jf));
rej0_IV_wMCa = squeeze(ests_stat(ACDIV_wMCa,4,j0:jf));
rej0_IV_wMC  = squeeze(ests_stat(ACDIV_wMC ,4,j0:jf));
rej0_IV_sMC  = squeeze(ests_stat(ACDIV_sMC ,4,j0:jf));

%Marker, colour (line and marker face) and legend entry per series:
%no controls (wNC), restricted controls (wMCa), baseline controls (wMC)
figB1_marker = {'d', 's', 'o'};
figB1_color  = {[0.6 0.6 0.6], [0.3 0.3 0.3], 'black'};
figB1_legend = {'No controls', 'Restricted set of controls', 'Baseline controls'};

%Panel: average estimated coefficient
figure('DefaultAxesFontSize',axes_size, 'Position', [10 10 600 600]);
plot_lines_avg = tiledlayout(1,1);
ymin = -0.06;
ymax = 0.06;
nexttile
figB1_series = {beta_IV_wNC, beta_IV_wMCa, beta_IV_wMC};
for figB1_k = 1:numel(figB1_series)
    plot(x_axis, figB1_series{figB1_k}, 'Marker', figB1_marker{figB1_k}, ...
        'MarkerFaceColor', figB1_color{figB1_k}, 'MarkerSize', marker_size, ...
        'Color', figB1_color{figB1_k})
    hold on
end
yline(0, '-black', 'HandleVisibility','off')
xline(0, '-black', 'HandleVisibility','off')
ylim([ymin ymax])
xlim([xmin xmax])
legend(figB1_legend{:}, 'Location', 'southeast', 'FontSize', legend_size)
xlabel('E_t[W (\Delta x* ) - W (\Delta x )]', 'FontSize', label_size)

%Panel: rejection rate
figure('DefaultAxesFontSize',axes_size, 'Position', [10 10 600 600]);
plot_lines_rej = tiledlayout(1,1);
ymin = 0;
ymax = 1;
nexttile
figB1_series = {rej0_IV_wNC, rej0_IV_wMCa, rej0_IV_wMC};
for figB1_k = 1:numel(figB1_series)
    plot(x_axis, figB1_series{figB1_k}, 'Marker', figB1_marker{figB1_k}, ...
        'MarkerFaceColor', figB1_color{figB1_k}, 'MarkerSize', marker_size, ...
        'Color', figB1_color{figB1_k})
    hold on
end
yline(0, '-black', 'HandleVisibility','off')
xline(0, '-black', 'HandleVisibility','off')
ylim([ymin ymax])
xlim([xmin xmax])
yticks([0 .2 .4 .6 .8 1])
legend(figB1_legend{:}, 'Location', 'north', 'FontSize', legend_size)
xlabel('E_t[W (\Delta x* ) - W (\Delta x )]', 'FontSize', label_size)

%Remove the plotting temporaries
clear figB1_marker figB1_color figB1_legend figB1_series figB1_k

saveas(plot_lines_avg, graph_path+ "Fig_B1b.png")
saveas(plot_lines_rej, graph_path+ "Fig_B1a.png")

saveas(plot_lines_avg, graph_path+ "Fig_B1b.eps", 'epsc');
saveas(plot_lines_rej, graph_path+ "Fig_B1a.eps", 'epsc');

%% Figure B2

%Compares the sMC IV (legend "equal weights") with the wMC IV across the
%misspecification configurations. Formatting constants, w_min/w_max and the
%ACDIV_* row indices come from the Figure B1 section.

%Average estimated coefficient
figure('DefaultAxesFontSize',axes_size, 'Position', [10 10 600 600]);
plot_lines_avg = tiledlayout(1,1);
ymin =  w_min ;
ymax =  w_max;

j0 = 1;
jf = Ngamma;
DeltaW = squeeze(ests_stat(3,1,j0:jf));
%Plot against the welfare gap computed here rather than an x_axis left
%over from an earlier section
x_axis = DeltaW;

%Series used in both panels, read directly from ests_stat
%(column 1: mean coefficient; column 4: mean r0 output of implement_test)
beta_IV_sMC = squeeze(ests_stat(ACDIV_sMC,1,j0:jf));
beta_IV_wMC = squeeze(ests_stat(ACDIV_wMC,1,j0:jf));
rej0_IV_sMC = squeeze(ests_stat(ACDIV_sMC,4,j0:jf));
rej0_IV_wMC = squeeze(ests_stat(ACDIV_wMC,4,j0:jf));

xmin = w_min;
xmax = w_max;

nexttile
plot(x_axis, beta_IV_sMC, 'Marker','d','MarkerFaceColor',[0.5 0.5 0.5], 'MarkerSize', marker_size, 'Color', [0.5 0.5 0.5])
hold on 
plot(x_axis, beta_IV_wMC, 'Marker','o','MarkerFaceColor','black', 'MarkerSize', marker_size, 'Color', 'black')
hold on 
yline(0, '-black', 'HandleVisibility','off')
xline(0, '-black', 'HandleVisibility','off')
ylim([ymin ymax])
xlim([xmin xmax])
legend('preferred IV, equal weights', 'preferred IV', 'Location', 'southeast', 'FontSize', legend_size)
xlabel('E_t[W (\Delta x* ) - W (\Delta x )]', 'FontSize', label_size)

%Rejection rate
figure('DefaultAxesFontSize',axes_size, 'Position', [10 10 600 600]);
plot_lines_rej = tiledlayout(1,1);

ymin = 0 ;
ymax = 1;

nexttile
plot(x_axis, rej0_IV_sMC, 'Marker','d','MarkerFaceColor',[0.5 0.5 0.5], 'MarkerSize', marker_size, 'Color', [0.5 0.5 0.5])
hold on 
plot(x_axis, rej0_IV_wMC, 'Marker','o','MarkerFaceColor','black', 'MarkerSize', marker_size, 'Color', 'black')
hold on 
yline(0, '-black', 'HandleVisibility','off')
xline(0, '-black', 'HandleVisibility','off')
ylim([ymin ymax])
xlim([xmin xmax])
yticks([0 .2 .4 .6 .8 1])
legend('preferred IV, equal weights', 'preferred IV', 'Location', 'north', 'FontSize', legend_size)
xlabel('E_t[W (\Delta x* ) - W (\Delta x )]', 'FontSize', label_size)

saveas(plot_lines_avg, graph_path+ "Fig_B2b.png")
saveas(plot_lines_rej, graph_path+ "Fig_B2a.png")

saveas(plot_lines_avg, graph_path+ "Fig_B2b.eps", 'epsc');
saveas(plot_lines_rej, graph_path+ "Fig_B2a.eps", 'epsc');
